import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import torch

# Neighbour lists for the 7-node label graph (keys 0-6), presumably the
# Pascal-Person-Part labels (inferred from the name -- confirm). Unlike
# cihp_graph/atr_graph, every node lists itself, so the graph carries
# explicit self-loops before preprocess_adj adds its identity matrix.
pascal_graph = {
    0: [0],
    1: [1, 2],
    2: [1, 2, 3, 5],
    3: [2, 3, 4],
    4: [3, 4],
    5: [2, 5, 6],
    6: [5, 6],
}

# Neighbour lists for the 20-node label graph (keys 0-19), presumably the
# CIHP labels (inferred from the name -- confirm). Node 0 has no neighbours,
# no node lists itself, and the lists are symmetric (j in graph[i] exactly
# when i in graph[j]).
cihp_graph = {
    0: [],
    1: [2, 13],
    2: [1, 13],
    3: [14, 15],
    4: [13],
    5: [6, 7, 9, 10, 11, 12, 14, 15],
    6: [5, 7, 10, 11, 14, 15, 16, 17],
    7: [5, 6, 9, 10, 11, 12, 14, 15],
    8: [16, 17, 18, 19],
    9: [5, 7, 10, 16, 17, 18, 19],
    10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
    11: [5, 6, 7, 10, 13],
    12: [5, 7, 10, 16, 17],
    13: [1, 2, 4, 10, 11],
    14: [3, 5, 6, 7, 10],
    15: [3, 5, 6, 7, 10],
    16: [6, 8, 9, 10, 12, 18],
    17: [6, 8, 9, 10, 12, 19],
    18: [8, 9, 16],
    19: [8, 9, 17],
}

# Neighbour lists for the 18-node label graph (keys 0-17), presumably the
# ATR labels (inferred from the name -- confirm). Node 0 has no neighbours
# and no node lists itself.
# NOTE(review): 5 and 6 list 7, but 7 does not list them back. This is
# harmless for preprocess_adj, because nx.from_dict_of_lists builds an
# undirected Graph by default, so the edge exists either way.
atr_graph = {
    0: [],
    1: [2, 11],
    2: [1, 11],
    3: [11],
    4: [5, 6, 7, 11, 14, 15, 17],
    5: [4, 6, 7, 8, 12, 13],
    6: [4, 5, 7, 8, 9, 10, 12, 13],
    7: [4, 11, 12, 13, 14, 15],
    8: [5, 6],
    9: [6, 12],
    10: [6, 13],
    11: [1, 2, 3, 4, 7, 14, 15, 17],
    12: [5, 6, 7, 9],
    13: [5, 6, 7, 10],
    14: [4, 7, 11, 16],
    15: [4, 7, 11, 16],
    16: [14, 15],
    17: [4, 11],
}

# Binary 7 x 20 integer matrix. Row p holds a 1 in column c for every c in
# the p-th member set below, and 0 elsewhere. Presumably row p is a Pascal
# part label and column c a CIHP label (shape matches pascal_graph x
# cihp_graph), with a 1 meaning "CIHP label c is grouped into Pascal part p"
# -- TODO confirm. Listing the member columns keeps the grouping readable.
cihp2pascal_adj = np.array([
    [1 if cihp_id in members else 0 for cihp_id in range(20)]
    for members in (
        {0},
        {1, 2, 4, 13},
        {5, 6, 7, 10, 11, 12},
        {7, 14, 15},
        {3, 7, 14, 15},
        {9, 16, 17},
        {8, 9, 16, 17, 18, 19},
    )
])

# Dense 7 x 20 float matrix (7 rows of 20 values), the same shape as
# cihp2pascal_adj. Presumably rows index the 7 Pascal labels and columns the
# 20 CIHP labels. Each entry is presumably a soft label-similarity score
# (the "nlp" in the name suggests word-embedding similarity) used in place of
# the hard 0/1 mapping -- TODO confirm how these were produced.
cihp2pascal_nlp_adj = \
    np.array([[ 1.,  0.35333052,  0.32727194,  0.17418084,  0.18757584,
         0.40608522,  0.37503981,  0.35448462,  0.22598555,  0.23893579,
         0.33064262,  0.28923404,  0.27986573,  0.4211553 ,  0.36915778,
         0.41377746,  0.32485771,  0.37248222,  0.36865639,  0.41500332],
       [ 0.39615879,  0.46201529,  0.52321467,  0.30826114,  0.25669527,
         0.54747773,  0.3670523 ,  0.3901983 ,  0.27519473,  0.3433325 ,
         0.52728509,  0.32771333,  0.34819325,  0.63882953,  0.68042925,
         0.69368576,  0.63395791,  0.65344337,  0.59538781,  0.6071375 ],
       [ 0.16373166,  0.21663339,  0.3053872 ,  0.28377612,  0.1372435 ,
         0.4448808 ,  0.29479995,  0.31092595,  0.22703953,  0.33983576,
         0.75778818,  0.2619818 ,  0.37069392,  0.35184867,  0.49877512,
         0.49979437,  0.51853277,  0.52517541,  0.32517741,  0.32377309],
       [ 0.32687232,  0.38482461,  0.37693463,  0.41610834,  0.20415749,
         0.76749079,  0.35139853,  0.3787411 ,  0.28411737,  0.35155421,
         0.58792618,  0.31141718,  0.40585111,  0.51189218,  0.82042737,
         0.8342413 ,  0.70732188,  0.72752501,  0.60327325,  0.61431337],
       [ 0.34069369,  0.34817292,  0.37525998,  0.36497069,  0.17841617,
         0.69746208,  0.31731463,  0.34628951,  0.25167277,  0.32072379,
         0.56711286,  0.24894776,  0.37000453,  0.52600859,  0.82483993,
         0.84966274,  0.7033991 ,  0.73449378,  0.56649608,  0.58888791],
       [ 0.28477487,  0.35139564,  0.42742352,  0.41664321,  0.20004676,
         0.78566833,  0.42237487,  0.41048549,  0.37933812,  0.46542516,
         0.62444759,  0.3274493 ,  0.49466009,  0.49314658,  0.71244233,
         0.71497003,  0.8234787 ,  0.83566589,  0.62597135,  0.62626812],
       [ 0.3011378 ,  0.31775977,  0.42922647,  0.36896257,  0.17597556,
         0.72214655,  0.39162804,  0.38137872,  0.34980296,  0.43818419,
         0.60879174,  0.26762545,  0.46271161,  0.51150476,  0.72318109,
         0.73678399,  0.82620388,  0.84942166,  0.5943811 ,  0.60607602]])

# Dense 7 x 18 float matrix (7 rows of 18 values). Presumably rows index the
# 7 Pascal labels (pascal_graph has 7 nodes) and columns the 18 ATR labels
# (atr_graph has 18 nodes). Entries are presumably label-similarity scores
# -- TODO confirm.
# NOTE(review): this is laid out (source, target) = (Pascal, ATR), whereas
# cihp2pascal_nlp_adj is (target, source) = (Pascal, CIHP). Check which
# orientation each consumer expects.
pascal2atr_nlp_adj = \
    np.array([[ 1.,  0.35333052,  0.32727194,  0.18757584,  0.40608522,
         0.27986573,  0.23893579,  0.27600672,  0.30964391,  0.36865639,
         0.41500332,  0.4211553 ,  0.32485771,  0.37248222,  0.36915778,
         0.41377746,  0.32006291,  0.28923404],
       [ 0.39615879,  0.46201529,  0.52321467,  0.25669527,  0.54747773,
         0.34819325,  0.3433325 ,  0.26603942,  0.45162929,  0.59538781,
         0.6071375 ,  0.63882953,  0.63395791,  0.65344337,  0.68042925,
         0.69368576,  0.44354613,  0.32771333],
       [ 0.16373166,  0.21663339,  0.3053872 ,  0.1372435 ,  0.4448808 ,
         0.37069392,  0.33983576,  0.26563416,  0.35443504,  0.32517741,
         0.32377309,  0.35184867,  0.51853277,  0.52517541,  0.49877512,
         0.49979437,  0.21750868,  0.2619818 ],
       [ 0.32687232,  0.38482461,  0.37693463,  0.20415749,  0.76749079,
         0.40585111,  0.35155421,  0.28271333,  0.52684576,  0.60327325,
         0.61431337,  0.51189218,  0.70732188,  0.72752501,  0.82042737,
         0.8342413 ,  0.40137029,  0.31141718],
       [ 0.34069369,  0.34817292,  0.37525998,  0.17841617,  0.69746208,
         0.37000453,  0.32072379,  0.27268885,  0.47426719,  0.56649608,
         0.58888791,  0.52600859,  0.7033991 ,  0.73449378,  0.82483993,
         0.84966274,  0.37830796,  0.24894776],
       [ 0.28477487,  0.35139564,  0.42742352,  0.20004676,  0.78566833,
         0.49466009,  0.46542516,  0.32662614,  0.55780359,  0.62597135,
         0.62626812,  0.49314658,  0.8234787 ,  0.83566589,  0.71244233,
         0.71497003,  0.41223219,  0.3274493 ],
       [ 0.3011378 ,  0.31775977,  0.42922647,  0.17597556,  0.72214655,
         0.46271161,  0.43818419,  0.3192333 ,  0.50979216,  0.5943811 ,
         0.60607602,  0.51150476,  0.82620388,  0.84942166,  0.72318109,
         0.73678399,  0.39259827,  0.26762545]])

# Dense 20 x 18 float matrix (20 rows of 18 values). Presumably rows index
# the 20 CIHP labels (cihp_graph has keys 0-19) and columns the 18 ATR labels
# (atr_graph has keys 0-17). Entries are presumably label-similarity scores;
# the entries exactly equal to 1. presumably mark identically named label
# pairs -- TODO confirm.
# NOTE(review): laid out (source, target) = (CIHP, ATR), the opposite of
# cihp2pascal_nlp_adj's (target, source) layout. Check which orientation the
# consumers expect.
cihp2atr_nlp_adj = np.array([[ 1.,  0.35333052,  0.32727194,  0.18757584,  0.40608522,
         0.27986573,  0.23893579,  0.27600672,  0.30964391,  0.36865639,
         0.41500332,  0.4211553 ,  0.32485771,  0.37248222,  0.36915778,
         0.41377746,  0.32006291,  0.28923404],
       [ 0.35333052,  1.        ,  0.39206695,  0.42143438,  0.4736689 ,
         0.47139544,  0.51999208,  0.38354847,  0.45628529,  0.46514124,
         0.50083501,  0.4310595 ,  0.39371443,  0.4319752 ,  0.42938598,
         0.46384034,  0.44833757,  0.6153155 ],
       [ 0.32727194,  0.39206695,  1.        ,  0.32836702,  0.52603065,
         0.39543695,  0.3622627 ,  0.43575346,  0.33866223,  0.45202552,
         0.48421   ,  0.53669903,  0.47266611,  0.50925436,  0.42286557,
         0.45403656,  0.37221304,  0.40999322],
       [ 0.17418084,  0.46892601,  0.25774838,  0.31816231,  0.39330317,
         0.34218382,  0.48253904,  0.22084125,  0.41335728,  0.52437572,
         0.5191713 ,  0.33576117,  0.44230914,  0.44250678,  0.44330833,
         0.43887264,  0.50693611,  0.39278795],
       [ 0.18757584,  0.42143438,  0.32836702,  1.        ,  0.35030067,
         0.30110947,  0.41055555,  0.34338879,  0.34336307,  0.37704433,
         0.38810141,  0.34702081,  0.24171562,  0.25433078,  0.24696241,
         0.2570884 ,  0.4465962 ,  0.45263213],
       [ 0.40608522,  0.4736689 ,  0.52603065,  0.35030067,  1.        ,
         0.54372584,  0.58300258,  0.56674191,  0.555266  ,  0.66599594,
         0.68567555,  0.55716359,  0.62997328,  0.65638548,  0.61219615,
         0.63183318,  0.54464151,  0.44293752],
       [ 0.37503981,  0.50675565,  0.4761106 ,  0.37561813,  0.60419403,
         0.77912403,  0.64595517,  0.85939662,  0.46037144,  0.52348817,
         0.55875094,  0.37741886,  0.455671  ,  0.49434392,  0.38479954,
         0.41804074,  0.47285709,  0.57236283],
       [ 0.35448462,  0.50576632,  0.51030446,  0.35841033,  0.55106903,
         0.50257274,  0.52591451,  0.4283053 ,  0.39991808,  0.42327211,
         0.42853819,  0.42071825,  0.41240559,  0.42259136,  0.38125352,
         0.3868255 ,  0.47604934,  0.51811717],
       [ 0.22598555,  0.5053299 ,  0.36301185,  0.38002282,  0.49700941,
         0.45625243,  0.62876479,  0.4112051 ,  0.33944371,  0.48322639,
         0.50318714,  0.29207815,  0.38801966,  0.41119094,  0.29199072,
         0.31021029,  0.41594871,  0.54961962],
       [ 0.23893579,  0.51999208,  0.3622627 ,  0.41055555,  0.58300258,
         0.68874251,  1.        ,  0.56977937,  0.49918447,  0.48484363,
         0.51615925,  0.41222306,  0.49535971,  0.53134951,  0.3807616 ,
         0.41050298,  0.48675801,  0.51112664],
       [ 0.33064262,  0.306412  ,  0.60679935,  0.25592294,  0.58738706,
         0.40379627,  0.39679161,  0.33618385,  0.39235148,  0.45474013,
         0.4648476 ,  0.59306762,  0.58976007,  0.60778661,  0.55400397,
         0.56551297,  0.3698029 ,  0.33860535],
       [ 0.28923404,  0.6153155 ,  0.40999322,  0.45263213,  0.44293752,
         0.60359359,  0.51112664,  0.46578181,  0.45656936,  0.38142307,
         0.38525582,  0.33327223,  0.35360175,  0.36156453,  0.3384992 ,
         0.34261229,  0.49297863,  1.        ],
       [ 0.27986573,  0.47139544,  0.39543695,  0.30110947,  0.54372584,
         1.        ,  0.68874251,  0.67765588,  0.48690078,  0.44010641,
         0.44921156,  0.32321099,  0.48311542,  0.4982002 ,  0.39378102,
         0.40297733,  0.45309735,  0.60359359],
       [ 0.4211553 ,  0.4310595 ,  0.53669903,  0.34702081,  0.55716359,
         0.32321099,  0.41222306,  0.25721705,  0.36633509,  0.5397475 ,
         0.56429928,  1.        ,  0.55796926,  0.58842844,  0.57930828,
         0.60410597,  0.41615326,  0.33327223],
       [ 0.36915778,  0.42938598,  0.42286557,  0.24696241,  0.61219615,
         0.39378102,  0.3807616 ,  0.28089866,  0.48450394,  0.77400821,
         0.68813814,  0.57930828,  0.8856886 ,  0.81673412,  1.        ,
         0.92279623,  0.46969152,  0.3384992 ],
       [ 0.41377746,  0.46384034,  0.45403656,  0.2570884 ,  0.63183318,
         0.40297733,  0.41050298,  0.332879  ,  0.48799542,  0.69231828,
         0.77015091,  0.60410597,  0.79788484,  0.88232104,  0.92279623,
         1.        ,  0.45685017,  0.34261229],
       [ 0.32485771,  0.39371443,  0.47266611,  0.24171562,  0.62997328,
         0.48311542,  0.49535971,  0.32477932,  0.51486622,  0.79353556,
         0.69768738,  0.55796926,  1.        ,  0.92373745,  0.8856886 ,
         0.79788484,  0.47883134,  0.35360175],
       [ 0.37248222,  0.4319752 ,  0.50925436,  0.25433078,  0.65638548,
         0.4982002 ,  0.53134951,  0.38057074,  0.52403969,  0.72035243,
         0.78711147,  0.58842844,  0.92373745,  1.        ,  0.81673412,
         0.88232104,  0.47109935,  0.36156453],
       [ 0.36865639,  0.46514124,  0.45202552,  0.37704433,  0.66599594,
         0.44010641,  0.48484363,  0.39636574,  0.50175258,  1.        ,
         0.91320249,  0.5397475 ,  0.79353556,  0.72035243,  0.77400821,
         0.69231828,  0.59087008,  0.38142307],
       [ 0.41500332,  0.50083501,  0.48421,  0.38810141,  0.68567555,
         0.44921156,  0.51615925,  0.45156472,  0.50438158,  0.91320249,
         1.,  0.56429928,  0.69768738,  0.78711147,  0.68813814,
         0.77015091,  0.57698754,  0.38525582]])



def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix.

    Computes ``D^-1/2 A^T D^-1/2``, where ``D`` is the diagonal matrix of the
    row sums of ``A``. For a symmetric ``A`` this is the usual GCN
    normalization ``D^-1/2 A D^-1/2``.

    Args:
        adj: Anything ``scipy.sparse.coo_matrix`` accepts (a dense array or a
            sparse matrix). It is assumed to be square.

    Returns:
        The normalized matrix as a ``scipy.sparse.coo_matrix``. Nodes whose
        row sum is zero (or negative) get a scaling factor of 0, so their rows
        and columns come out as zeros instead of inf/NaN.
    """
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    # A zero row sum gives 0 ** -0.5 == inf, and a negative one gives NaN.
    # Both cases are expected and handled just below, so silence the
    # RuntimeWarnings numpy would otherwise emit for them.
    with np.errstate(divide='ignore', invalid='ignore'):
        d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[~np.isfinite(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    # (A D)^T D == D A^T D, because D is diagonal.
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

def preprocess_adj(adj):
    """Turn a dict-of-lists label graph into a normalized dense adjacency.

    The dict is loaded into a networkx graph and converted to a scipy sparse
    adjacency matrix. An identity matrix is added on top of any self-loops
    already in the dict (``A + I``), and the result is passed through
    ``normalize_adj``.

    Args:
        adj: Mapping ``node -> list of neighbour nodes`` (e.g. ``cihp_graph``).

    Returns:
        A dense ``numpy.matrix`` of shape ``(n_nodes, n_nodes)``. Rows and
        columns follow networkx's node order, which for
        ``from_dict_of_lists`` is the dict's key order.
    """
    graph = nx.from_dict_of_lists(adj)
    # NOTE: this is a scipy sparse matrix/array, not a numpy array.
    sparse_adj = nx.adjacency_matrix(graph)
    n_nodes = sparse_adj.shape[0]
    with_self_loops = sparse_adj + sp.eye(n_nodes)
    normalized = normalize_adj(with_self_loops)
    return normalized.todense()

def row_norm(inputs):
    """Scale each row of ``inputs`` so that it sums to one.

    Args:
        inputs: An iterable of rows that support ``.sum()`` and division,
            e.g. a 2-D numpy array or torch tensor.

    Returns:
        A list with one normalized row per input row (a list, not a stacked
        array). A row whose sum is zero comes back as an all-zero row instead
        of the NaNs that ``0 / 0`` would produce. This matches how
        ``normalize_adj`` zeroes out zero-degree rows.
    """
    outputs = []
    for row in inputs:
        total = row.sum()
        # ``row * 0.0`` keeps the row's container type (ndarray / tensor) and
        # promotes integer rows to float, just as ``row / total`` does.
        outputs.append(row / total if total != 0 else row * 0.0)
    return outputs


def normalize_adj_torch(adj):
    """Symmetrically normalize adjacency matrices: ``D^-1/2 A D^-1/2``.

    ``D`` is the diagonal matrix of the row sums of ``A``. The last two
    dimensions of ``adj`` hold the (square) adjacency matrix. Any leading
    dimensions, e.g. ``(B, C, N, N)``, are batch dimensions, and every matrix
    in the batch is normalized independently.

    Args:
        adj: Tensor of shape ``(..., N, N)``.

    Returns:
        A new tensor with the same shape as ``adj``; ``adj`` itself is not
        modified. Nodes whose row sum is zero (or negative) get a scaling
        factor of 0, so their rows and columns become zeros rather than
        inf/NaN.
    """
    rowsum = adj.sum(-1)
    d_inv_sqrt = rowsum.pow(-0.5)
    # 0 ** -0.5 is inf (not NaN), and negative sums give NaN. Zero out both,
    # so an isolated node cannot turn the whole result into NaNs.
    d_inv_sqrt = torch.where(torch.isfinite(d_inv_sqrt), d_inv_sqrt,
                             torch.zeros_like(d_inv_sqrt))
    # Same as diag(d) @ A @ diag(d): scale row i by d[i] and column j by d[j].
    # Broadcasting applies this to every matrix in the batch at once.
    return d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)

# def row_norm(adj):




if __name__ == '__main__':
    # Manual check: print the row-normalized CIHP->Pascal mapping, then the
    # raw 0/1 matrix it was computed from.
    normalized_rows = row_norm(cihp2pascal_adj)
    print(normalized_rows)
    print(cihp2pascal_adj)
